//===----------------------------------------------------------------------===//
//
//                         BusTub
//
// hash_join_executor.cpp
//
// Identification: src/execution/hash_join_executor.cpp
//
// Copyright (c) 2015-2021, Carnegie Mellon University Database Group
//
//===----------------------------------------------------------------------===//

#include "execution/executors/hash_join_executor.h"
#include "type/value_factory.h"

namespace bustub {

/**
 * Constructs a hash join executor.
 *
 * @param exec_ctx    executor context, forwarded to the AbstractExecutor base
 * @param plan        hash join plan node; the pointer is stored in plan_
 * @param left_child  executor for the left input; ownership is moved into left_child_
 * @param right_child executor for the right input; ownership is moved into right_child_
 * @throws NotImplementedException if the plan's join type is neither LEFT nor INNER
 */
HashJoinExecutor::HashJoinExecutor(ExecutorContext *exec_ctx, const HashJoinPlanNode *plan,
                                   std::unique_ptr<AbstractExecutor> &&left_child,
                                   std::unique_ptr<AbstractExecutor> &&right_child)
    : AbstractExecutor(exec_ctx),
      plan_(plan),
      left_child_(std::move(left_child)),
      right_child_(std::move(right_child)) {
  // Only left and inner joins are implemented (see the 2023 Spring project note);
  // every other join type is rejected up front.
  const auto join_type = plan->GetJoinType();
  const bool is_supported = join_type == JoinType::LEFT || join_type == JoinType::INNER;
  if (!is_supported) {
    throw bustub::NotImplementedException(fmt::format("join type {} not supported", join_type));
  }
}

void HashJoinExecutor::Init() {
  LOG_DEBUG("Init");
  left_child_->Init();
  right_child_->Init();
  hjt_.Clear();
  res_.clear();
  cursor_ = 0;
  Tuple tuple;
  RID rid;
  if (plan_->RightJoinKeyExpressions().empty()) {
    LOG_DEBUG("Init1");
    return;
  }

  while (right_child_->Next(&tuple, &rid)) {
    hjt_.InsertCombine(MakeRightJoinKey(&tuple), MakeJoinValue(&tuple));
  }

  if (plan_->GetJoinType() == JoinType::INNER) {
    GenForInnerJoin();
    return;
  }
  GenForLeftJoin();
}

/**
 * Emits the next tuple of the result that Init() stored in res_.
 *
 * @param[out] tuple receives a copy of res_[cursor_] when a tuple is available
 * @param[out] rid   not written by this function
 * @return true if a tuple was produced, false once res_ is exhausted
 */
auto HashJoinExecutor::Next(Tuple *tuple, RID *rid) -> bool {
  if (cursor_ < res_.size()) {
    *tuple = res_[cursor_];
    ++cursor_;
    return true;
  }
  return false;
}

auto HashJoinExecutor::GenForInnerJoin() -> void {
  LOG_DEBUG("NextForInnerJoin");
  Tuple left_tuple;
  RID left_rid;
  while (left_child_->Next(&left_tuple, &left_rid)) {
    auto join_key = MakeLeftJoinKey(&left_tuple);
    if (hjt_.Count(join_key) != 0) {
      LOG_DEBUG("hit_key=%s", join_key.ToString().c_str());
      auto range = hjt_.ht_.equal_range(join_key);

      for (auto it = range.first; it != range.second; it++) {
        int left_col = left_child_->GetOutputSchema().GetColumnCount();
        std::vector<Value> values;
        values.reserve(left_col);
        for (int i = 0; i < left_col; i++) {  // 左表的值
          values.push_back(left_tuple.GetValue(&(left_child_->GetOutputSchema()), static_cast<uint32_t>(i)));
        }
        values.insert(values.end(), it->second.join_cols_.begin(), it->second.join_cols_.end());
        LOG_DEBUG("NextForInnerJoin, left_col=%d, right_col=%ld", left_col, it->second.join_cols_.size());
        auto tuple = Tuple(values, &(GetOutputSchema()));
        LOG_DEBUG("tuple=%s", tuple.ToString(&GetOutputSchema()).c_str());
        res_.push_back(tuple);
      }
    }
  }
}

auto HashJoinExecutor::GenForLeftJoin() -> void {
  Tuple left_tuple;
  RID left_rid;
  while (left_child_->Next(&left_tuple, &left_rid)) {
    auto join_key = MakeLeftJoinKey(&left_tuple);
    if (hjt_.Count(join_key) != 0) {
      LOG_DEBUG("hit_key=%s", join_key.ToString().c_str());
      auto range = hjt_.ht_.equal_range(join_key);
      for (auto it = range.first; it != range.second; it++) {
        int left_col = left_child_->GetOutputSchema().GetColumnCount();
        std::vector<Value> values;
        values.reserve(left_col);
        for (int i = 0; i < left_col; i++) {  // 左表的值
          values.push_back(left_tuple.GetValue(&(left_child_->GetOutputSchema()), static_cast<uint32_t>(i)));
        }
        values.insert(values.end(), it->second.join_cols_.begin(), it->second.join_cols_.end());
        LOG_DEBUG("NextForInnerJoin, left_col=%d, right_col=%ld", left_col, it->second.join_cols_.size());
        auto tuple = Tuple(values, &(GetOutputSchema()));
        LOG_DEBUG("tuple=%s", tuple.ToString(&GetOutputSchema()).c_str());
        res_.push_back(tuple);
      }
    } else {
      int left_col = left_child_->GetOutputSchema().GetColumnCount();
      std::vector<Value> values;
      values.reserve(left_col);
      for (int i = 0; i < left_col; i++) {  // 左表的值
        values.push_back(left_tuple.GetValue(&(left_child_->GetOutputSchema()), static_cast<uint32_t>(i)));
      }
      // NULL
      std::vector<Value> right_values;
      int right_col = right_child_->GetOutputSchema().GetColumnCount();
      right_values.reserve(right_col);
      for (int i = 0; i < right_col; i++) {  // 右表的值
        right_values.push_back(
            ValueFactory::GetNullValueByType(right_child_->GetOutputSchema().GetColumn(i).GetType()));
      }
      values.insert(values.end(), right_values.begin(), right_values.end());
      auto tuple = Tuple(values, &(GetOutputSchema()));
      res_.push_back(tuple);
    }
  }
}

}  // namespace bustub
